The objective of this notebook is to test flow matching on simple 2D datasets.
# Navigate to the root of the project and import the necessary libraries
import sys
import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from hydra import initialize, compose

# Change the working directory to the root of the project.
# Run this cell only once.
root_path = Path.cwd().parent
os.chdir(root_path)
sys.path.append(str(root_path / "src"))

# version_base=None avoids Hydra's compatibility warning
initialize(config_path="../confs", job_name="notebook", version_base=None)
from experiments.datasets.datasets import GMM2GMM, get_first_example_dataset, get_second_example_dataset
from torch.utils.data import DataLoader
# First example: GMM with 2 components -> GMM with 3 components
# Second example: Gaussian -> GMM with 5 components
n_points = 500000
# dataset = get_first_example_dataset(n_points)
dataset = get_second_example_dataset(n_points)
loader = DataLoader(dataset, batch_size=512, shuffle=True)
# Let's visualize the first batch
from src.experiments.visualization.plots import *
batch = next(iter(loader))
x0, x1, y = batch
scatter_points(x0, x1)
We use Hydra to specify our models in .yaml configuration files; this simplifies model loading, experimentation, and configuration management:
from hydra.utils import instantiate
from omegaconf import OmegaConf
from src.flows.types import Predicts
os.environ["HYDRA_FULL_ERROR"] = "1"
cfg = compose(config_name="flow_model/toy_flow")
flow_model = instantiate(cfg.flow_model)
Here is how a configuration file for a Flow is usually structured.
`_target_` points to the object we would like to instantiate. One of Hydra's main strengths is that configurations can inherit from sub-configurations defined in other files, with defaults for their values. These .yaml configuration files are also easy to export, which makes them convenient for logging experiments.
Note that it is important to set `_partial_: true` for `optimizer_cfg` and `scheduler_cfg` when they are specified, since their targets can only be fully constructed later, once the model parameters exist.
# import the function to display markdown
from IPython.display import display, Markdown
# Display the model configuration
display(Markdown(f"```yaml\n{OmegaConf.to_yaml(cfg)}\n```"))
flow_model:
  _target_: src.flows.flow.Flow
  path:
    _target_: src.flows.path.AffinePath
    scheduler:
      _target_: src.flows.schedulers.OTScheduler
  loss_fn:
    _target_: src.flows.losses.MSEFlowMatchingLoss
  cfg:
    _target_: src.flows.types.FlowConfig
    predicts: x_0
  optimizer_cfg:
    _target_: torch.optim.Adam
    _partial_: true
    lr: 0.0005
  model:
    _target_: src.experiments.models.models.TimeConditionedMLP
    x_dim: 2
    output_dim: 2
    num_fourier_bands: 4
    hidden_dim: 64
    n_layers: 4
loader = DataLoader(dataset, batch_size=80, shuffle=True)
x0, x1, y = next(iter(loader))
t, x_t = flow_model.path.sample(x_0=x0, x_1=x1)
target_vectors = flow_model.path.target_velocity(t, x0, x1)
scatter_points_with_velocity(x0, x1, x_t, target_vectors)
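As a sanity check on what the plot above shows, assuming `AffinePath` with an `OTScheduler` reduces to the standard linear interpolation path, the sampled point and target velocity have closed forms that we can reproduce independently of the library: x_t = (1 - t) x_0 + t x_1 and u_t = x_1 - x_0, so the target velocity simply points from the source sample to the data sample, independent of t.

```python
import torch

# Linear (OT) interpolation path between one source and one data point
x0 = torch.tensor([[0.0, 0.0]])
x1 = torch.tensor([[2.0, 4.0]])
t = torch.tensor([[0.25]])

# Point on the path at time t, and the (constant-in-t) target velocity
x_t = (1 - t) * x0 + t * x1
u_t = x1 - x0
```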
For training models, we resort to PyTorch Lightning, which fully automates the training loop. This module is very powerful, supporting many options such as logging, checkpointing, callbacks, and multi-device training.
So let's load our Trainer class:
trainer_cfg = compose(config_name="trainer/toy_trainer")
# Display the trainer configuration
display(Markdown(f"```yaml\n{OmegaConf.to_yaml(trainer_cfg)}\n```"))
trainer = instantiate(trainer_cfg.trainer)
trainer:
  _target_: pytorch_lightning.Trainer
  log_every_n_steps: 10
  num_sanity_val_steps: 2
  check_val_every_n_epoch: 1
  accelerator: gpu
  devices: 1
  callbacks:
  - _target_: pytorch_lightning.callbacks.TQDMProgressBar
    refresh_rate: 100
    leave: true
  max_epochs: 3
And now, we fit our model:
import logging
logging.basicConfig(level=logging.INFO)
loader = DataLoader(dataset, batch_size=1000, shuffle=True)
# train the model
trainer.fit(flow_model, train_dataloaders=loader)
  | Name  | Type               | Params | Mode
-----------------------------------------------------
0 | path  | AffinePath         | 0      | train
1 | model | TimeConditionedMLP | 13.3 K | train
-----------------------------------------------------
13.3 K    Trainable params
0         Non-trainable params
13.3 K    Total params
0.053     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=3` reached.
flow_model.eval()  # also puts the path and model submodules in eval mode
t, x_t = flow_model.path.sample(x_0=x0, x_1=x1)
# get u_theta(t, x_t)
v = flow_model.estimated_velocity(t, x_t).detach()
scatter_points_with_velocity(x0, x1, x_t, v)
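Once the velocity field is trained, new samples can be generated by integrating the ODE dx/dt = u_theta(t, x) from t = 0 to t = 1. The library may provide its own sampler; the sketch below is a minimal Euler integrator where `velocity_fn` stands in for `flow_model.estimated_velocity` (here replaced by a toy constant field so the result is checkable by hand):

```python
import torch

def euler_sample(velocity_fn, x, n_steps=100):
    """Euler-integrate dx/dt = velocity_fn(t, x) from t=0 to t=1."""
    dt = 1.0 / n_steps
    t = torch.zeros(x.shape[0], 1)
    for _ in range(n_steps):
        x = x + dt * velocity_fn(t, x)
        t = t + dt
    return x

# Toy constant field: moves every point one unit along each axis over t in [0, 1]
x0 = torch.zeros(8, 2)
x1 = euler_sample(lambda t, x: torch.ones_like(x), x0)
```

For a real model, `lambda t, x: flow_model.estimated_velocity(t, x)` would be passed instead, with `x0` drawn from the source distribution.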
import torch

def estimate_score(t, x_t):
    # The score blows up at the path endpoints (the conditional variance
    # vanishes there), so return a small placeholder field instead
    if (t == 0).any() or (t == 1).any():
        return torch.zeros_like(x_t) + 0.1
    v = flow_model.estimated_velocity(t, x_t)
    score = flow_model.path.convert_parameterization(t, x_t, v, "v", "score")
    # Normalize so that arrow lengths are comparable in the plots
    normalized_score = score / torch.norm(score, dim=1, keepdim=True)
    return normalized_score
def estimate_velocity(t, x_t, norm=True):
    v = flow_model.estimated_velocity(t, x_t)
    # Optionally rescale by the largest norm for nicer arrow lengths
    return v / torch.norm(v, dim=1, keepdim=True).max() if norm else v
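To make the `"v"` to `"score"` conversion above concrete, here is a hedged sketch of the identity it presumably implements for the linear Gaussian path x_t = (1 - t) x_0 + t x_1 with x_0 ~ N(0, I), as in the second example where the source is a Gaussian. Conditionally on x_1, x_t ~ N(t x_1, (1 - t)^2 I), so the conditional score is -(x - t x_1) / (1 - t)^2 and the conditional velocity is (x_1 - x) / (1 - t); eliminating x_1 gives score = (t v - x) / (1 - t). The actual convention in `convert_parameterization` may differ:

```python
import torch

# One data point, one query point on the path at time t
x1 = torch.tensor([2.0, 4.0])
t = 0.5
x = torch.tensor([1.0, 3.0])

# Score of the conditional Gaussian N(t * x1, (1 - t)^2 I), computed directly
score_direct = -(x - t * x1) / (1 - t) ** 2

# Conditional velocity, then score recovered from it via the identity
v = (x1 - x) / (1 - t)
score_from_v = (t * v - x) / (1 - t)
```

Both expressions agree, which is exactly what lets a trained velocity field double as a score estimate away from the endpoints t = 0 and t = 1.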
from IPython.display import HTML

ani = animate_estimated_velocity(x0, x1, estimate_velocity, device="cpu")
HTML(ani.to_jshtml())